/*
    Copyright (C) 2022 Tenable, Inc.

	Licensed under the Apache License, Version 2.0 (the "License");
    you may not use this file except in compliance with the License.
    You may obtain a copy of the License at

		http://www.apache.org/licenses/LICENSE-2.0

	Unless required by applicable law or agreed to in writing, software
    distributed under the License is distributed on an "AS IS" BASIS,
    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    See the License for the specific language governing permissions and
    limitations under the License.
*/

package vulnerability

import (
	"context"
	"fmt"
	"os"
	"strings"

	containeranalysis "cloud.google.com/go/containeranalysis/apiv1"
	"github.com/GoogleCloudPlatform/docker-credential-gcr/config"
	"github.com/GoogleCloudPlatform/docker-credential-gcr/credhelper"
	"github.com/GoogleCloudPlatform/docker-credential-gcr/store"
	"github.com/google/go-containerregistry/pkg/authn"
	"github.com/google/go-containerregistry/pkg/name"
	"github.com/google/go-containerregistry/pkg/v1/remote"
	"github.com/tenable/terrascan/pkg/iac-providers/output"
	"go.uber.org/zap"
	"google.golang.org/api/iterator"
	grafeaspb "google.golang.org/genproto/googleapis/grafeas/v1"
)

// kind is the Grafeas occurrence kind requested when listing occurrences.
const kind = "VULNERABILITY"

// Host names used to recognise Google-hosted registries (see checkRegistry).
const (
	// gcrURL is the Container Registry host.
	gcrURL = "gcr.io"
	// gcrArtifactURL is the Artifact Registry Docker host suffix.
	gcrArtifactURL = "docker.pkg.dev"
)

// googleCredPath names the environment variable whose value is handed to the
// GCR credential store in getCredentials.
const googleCredPath = "GOOGLE_APPLICATION_CREDENTIALS"

// failToGetDigestMsg is the format string for digest-lookup failures. Its
// arguments are the image reference followed by the underlying cause.
var failToGetDigestMsg = "failed to get image digest for image %s : %v"

// gcrScanner abstracts the Container Analysis calls GCR depends on, so the
// scanning logic can be exercised without a live Google Cloud client.
type gcrScanner interface {
	// newClient opens a Container Analysis client.
	newClient(ctx context.Context) (*containeranalysis.Client, error)
	// close releases a client obtained from newClient.
	close(c *containeranalysis.Client) error
	// listOccurrences returns every occurrence matching the request.
	listOccurrences(ctx context.Context, c *containeranalysis.Client, request *grafeaspb.ListOccurrencesRequest) ([]*grafeaspb.Occurrence, error)
}

type (
	// gscanner is the production gcrScanner; its methods call the
	// containeranalysis client library directly.
	gscanner struct{}

	// GCR is the container-registry integration for Google Container
	// Registry and Artifact Registry images.
	GCR struct {
		// scanner is the backend used to query vulnerability occurrences.
		scanner gcrScanner
	}
)

func init() {
	RegisterContainerRegistry("gcr", &GCR{
		scanner: gscanner{},
	})
}

// newClient creates a Container Analysis client bound to ctx, using the
// library's default credential discovery.
func (gscanner) newClient(ctx context.Context) (*containeranalysis.Client, error) {
	c, err := containeranalysis.NewClient(ctx)
	return c, err
}

// close shuts down a client previously returned by newClient.
func (gscanner) close(client *containeranalysis.Client) error {
	err := client.Close()
	return err
}

// listOccurrences drains the Grafeas ListOccurrences iterator for req and
// returns every occurrence it yields.
//
// If iteration fails part-way, the error is logged and returned together
// with the occurrences collected so far. The slice is never nil.
func (gscanner) listOccurrences(ctx context.Context, client *containeranalysis.Client, req *grafeaspb.ListOccurrencesRequest) ([]*grafeaspb.Occurrence, error) {
	it := client.GetGrafeasClient().ListOccurrences(ctx, req)
	collected := []*grafeaspb.Occurrence{}
	for {
		occurrence, err := it.Next()
		switch {
		case err == iterator.Done:
			return collected, nil
		case err != nil:
			zap.S().Errorf(errorScanningMsg, "", err)
			return collected, err
		}
		collected = append(collected, occurrence)
	}
}

// checkRegistry reports whether image is hosted on a Google registry:
// Container Registry (gcr.io and its subdomains such as us.gcr.io) or
// Artifact Registry (regional hosts such as us-docker.pkg.dev).
//
// A plain suffix test would also accept unrelated hosts like "notgcr.io".
// The match is therefore anchored on a label boundary: either the host is
// exactly the base domain, or the base is preceded by '.' or '-'.
func (g *GCR) checkRegistry(image string) bool {
	host := GetDomain(image)
	for _, base := range []string{gcrURL, gcrArtifactURL} {
		if host == base {
			return true
		}
		if strings.HasSuffix(host, "."+base) {
			return true
		}
		// Artifact Registry hosts are LOCATION-docker.pkg.dev.
		if base == gcrArtifactURL && strings.HasSuffix(host, "-"+base) {
			return true
		}
	}
	return false
}

// getVulnerabilities scans container.Image with ScanImage and converts each
// returned occurrence into an output.Vulnerability.
//
// A scan failure is logged and yields a nil slice. The options argument is
// not used by this registry.
func (g *GCR) getVulnerabilities(container output.ContainerDetails, options map[string]interface{}) (vulnerabilities []output.Vulnerability) {
	occurrences, scanErr := g.ScanImage(container.Image)
	if scanErr != nil {
		zap.S().Errorf("error finding vulnerabilities for image %s : %v", container.Image, scanErr)
		return nil
	}
	for i := range occurrences {
		var v output.Vulnerability
		v.PrepareFromGCRImageScan(occurrences[i])
		vulnerabilities = append(vulnerabilities, v)
	}
	return vulnerabilities
}

// ScanImage returns the VULNERABILITY occurrences that Container Analysis
// holds for image.
//
// The image is first normalised to a digest-qualified reference. A tag-only
// reference is resolved to a digest through the registry. The GCP project ID
// is taken from the second path segment (e.g. gcr.io/<project>/...).
// Occurrences are then listed under projects/<project>, filtered to the
// image's https:// resource URL.
func (g *GCR) ScanImage(image string) (result []*grafeaspb.Occurrence, err error) {
	ctx := context.Background()

	if image, err = getImageNameCompatibleForScan(image); err != nil {
		return nil, err
	}

	projectID, err := getProjectIDFromImageName(image)
	if err != nil {
		return nil, err
	}

	resourceURL := fmt.Sprintf("https://%s", image)

	client, err := g.scanner.newClient(ctx)
	if err != nil {
		zap.S().Errorf(errorScanningMsg, image, err)
		return nil, err
	}
	// A close failure cannot change the scan result. Log it instead of
	// discarding it silently.
	defer func() {
		if closeErr := g.scanner.close(client); closeErr != nil {
			zap.S().Warnf("failed to close container analysis client for image %s : %v", image, closeErr)
		}
	}()

	req := &grafeaspb.ListOccurrencesRequest{
		Parent: fmt.Sprintf("projects/%s", projectID),
		Filter: fmt.Sprintf("resourceUrl = %q kind = %q", resourceURL, kind),
	}

	return g.scanner.listOccurrences(ctx, client, req)
}

// getImageNameCompatibleForScan - returns image name which has digest
func getImageNameCompatibleForScan(image string) (string, error) {
	imageDetails := ImageDetails{}

	imageDetails = GetImageDetails(image, imageDetails)
	if imageDetails.Tag != "" {
		image = strings.Replace(image, colon+imageDetails.Tag, "", 1)
	}
	if imageDetails.Digest == "" {
		if imageDetails.Tag != "" {
			image = image + colon + imageDetails.Tag
		}
		digest, err := findDigestForImage(image)
		if err != nil {
			return "", err
		}
		if imageDetails.Tag != "" {
			image = strings.TrimSuffix(image, colon+imageDetails.Tag)
		}
		image = image + atTheRate + digest
	}
	return image, nil
}

// getProjectIDFromImageName extracts the GCP project ID from an image
// reference. This is the second '/'-separated segment, e.g. "my-project" in
// "gcr.io/my-project/app@sha256:...".
//
// An error is logged and returned when that segment is missing or empty.
// An empty ID would otherwise produce a malformed "projects/" parent in the
// Container Analysis request.
func getProjectIDFromImageName(image string) (string, error) {
	segments := strings.Split(image, "/")
	if len(segments) < 2 || segments[1] == "" {
		err := fmt.Errorf(invalidImageReferenceMsg, image)
		zap.S().Error(err)
		return "", err
	}
	return segments[1], nil
}

// findDigestForImage resolves an image reference (typically repo:tag) to the
// digest string of its image, fetching the manifest from the registry with
// credentials from getCredentials.
//
// On failure it returns an empty digest and the underlying error. Every
// failure except a credential error (already logged by getCredentials) is
// logged here.
func findDigestForImage(image string) (digest string, err error) {
	// Debugf, not Debug: the sugared Debug concatenates its arguments
	// Sprint-style and would emit the literal "%s".
	zap.S().Debugf("fetching digest for image %s", image)

	ref, err := name.ParseReference(image)
	if err != nil {
		zap.S().Errorf(invalidImageReferenceMsg, image)
		return "", err
	}

	auth, err := getCredentials(image, ref.Context().RegistryStr())
	if err != nil {
		return "", err
	}

	// logFailure records why the lookup failed. The raw cause is still what
	// gets returned to the caller.
	logFailure := func(cause error) {
		zap.S().Error(fmt.Errorf(failToGetDigestMsg, image, cause))
	}

	desc, err := remote.Get(ref, remote.WithAuth(&auth))
	if err != nil {
		logFailure(err)
		return "", err
	}

	img, err := desc.Image()
	if err != nil {
		logFailure(err)
		return "", err
	}

	hash, err := img.Digest()
	if err != nil {
		logFailure(err)
		return "", err
	}
	return hash.String(), nil
}

// getCredentials builds basic-auth registry credentials for domain using the
// docker-credential-gcr helper.
//
// The path in GOOGLE_APPLICATION_CREDENTIALS is handed to the helper's
// credential store. If the variable is unset, an error is returned. image is
// only used to give log and error messages context. Every failure is logged,
// and a zero authn.Basic is returned alongside the error.
func getCredentials(image string, domain string) (auth authn.Basic, err error) {
	gcpCredPath := os.Getenv(googleCredPath)
	if gcpCredPath == "" {
		// err is still nil at this point. Formatting it would report a useless
		// "<nil>" cause, so name the missing variable instead.
		err = fmt.Errorf(failToGetDigestMsg, image, fmt.Errorf("environment variable %s is not set", googleCredPath))
		zap.S().Error(err)
		return authn.Basic{}, err
	}
	credStore := store.NewGCRCredStore(gcpCredPath)

	userCfg, err := config.LoadUserConfig()
	if err != nil {
		zap.S().Error(fmt.Errorf(failToGetDigestMsg, image, err))
		return authn.Basic{}, err
	}
	helper := credhelper.NewGCRCredentialHelper(credStore, userCfg)

	username, password, err := helper.Get(domain)
	if err != nil {
		zap.S().Error(fmt.Errorf(failToGetDigestMsg, image, err))
		return authn.Basic{}, err
	}

	return authn.Basic{Username: username, Password: password}, nil
}
